#!/usr/bin/env python3
# J27 — TOF SI Overlay (seconds)
# CONTROL: present-act, boolean/ordinal. Integer-radius shells; at tick t, shell r0+t fires.
# DETECTORS: annuli given by inner radius and width, both in shells. Arrival tick = first tick whose firing shell lies inside the annulus.
# OVERLAY: convert native distances (shells) & times (ticks) to meters/seconds using manifest scales.
# INVARIANCE: native arrays are bit-identical before/after overlay; c_hat(SI) == c_hat(native)*dx/dt (within tol).

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

# ---------- utils ----------
def utc_ts() -> str:
    """Return the current UTC time as a filesystem-safe stamp (YYYY-MM-DDTHH-MM-SSZ)."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y-%m-%dT%H-%M-%SZ")

def ensure_dirs(root: str, subs: List[str]) -> None:
    """Create ``root/<sub>`` for every entry in *subs*, including missing parents.

    Directories that already exist are left untouched.
    """
    targets = (os.path.join(root, sub) for sub in subs)
    for target in targets:
        os.makedirs(target, exist_ok=True)

def write_text(path: str, text: str) -> None:
    """Write *text* to *path* as UTF-8, overwriting any existing content."""
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(text)

def json_dump(path: str, obj: dict) -> None:
    """Serialise *obj* to *path* as UTF-8 JSON with 2-space indent and sorted keys."""
    with open(path, mode="w", encoding="utf-8") as out:
        json.dump(obj, out, sort_keys=True, indent=2)

def sha256_bytes(b: bytes) -> str:
    """Return the lowercase hex SHA-256 digest of *b*."""
    # hashlib is imported at module level; no local re-import needed.
    return hashlib.sha256(b).hexdigest()

# ---------- present-act TOF ----------
def build_shell_map(N: int, cx: int, cy: int) -> Dict[int, List[Tuple[int,int]]]:
    shells: Dict[int, List[Tuple[int,int]]] = {}
    for y in range(N):
        for x in range(N):
            dx = x - cx; dy = y - cy
            r = math.isqrt(dx*dx + dy*dy)
            shells.setdefault(r, []).append((x,y))
    return shells

def simulate_lightfront(H: int, r0: int, det_radii: List[int], det_w: int,
                        shells: Dict[int, List[Tuple[int,int]]]) -> List[dict]:
    """Advance a one-shell-per-tick front and record first arrival per detector.

    At tick t (0 <= t < H) the front occupies shell ``r0 + t``; ticks whose
    shell has no cells in *shells* are skipped. Detector i covers shells
    ``det_radii[i] .. det_radii[i] + det_w - 1`` inclusive, and its arrival
    tick is the first tick whose shell falls in that range.

    Returns:
        One dict per detector, in input order, with keys ``detector_id``,
        ``radius_shells`` and ``arrival_tick`` (None if never reached).
    """
    arrivals = [
        {"detector_id": idx, "radius_shells": radius, "arrival_tick": None}
        for idx, radius in enumerate(det_radii)
    ]
    for tick in range(H):
        front = r0 + tick
        if not shells.get(front, []):
            continue  # no grid cells at this radius -> nothing fires
        for idx, inner in enumerate(det_radii):
            record = arrivals[idx]
            if record["arrival_tick"] is None and inner <= front <= (inner + det_w - 1):
                record["arrival_tick"] = tick
    return arrivals

# ---------- diagnostics ----------
def linreg_slope(xs: List[float], ys: List[float]) -> Tuple[float, float]:
    """Ordinary least-squares fit of ys on xs.

    Returns:
        ``(slope, r_squared)``. Both are NaN when fewer than two points are
        given, the lengths differ, or all xs are equal. When all ys are equal
        (zero total variance) r_squared is reported as 1.0.
    """
    nan = float("nan")
    count = len(xs)
    if count < 2 or len(ys) != count:
        return nan, nan
    x_mean = sum(xs) / count
    y_mean = sum(ys) / count
    pairs = list(zip(xs, ys))
    s_xy = sum((x - x_mean) * (y - y_mean) for x, y in pairs)
    s_xx = sum((x - x_mean) * (x - x_mean) for x in xs)
    if s_xx == 0:
        return nan, nan
    slope = s_xy / s_xx
    s_yy = sum((y - y_mean) * (y - y_mean) for y in ys)
    s_res = sum((y - (y_mean + slope * (x - x_mean))) ** 2 for x, y in pairs)
    if s_yy > 0:
        r_squared = 1.0 - s_res / s_yy
    else:
        r_squared = 1.0 - 0.0
    return slope, r_squared

# ---------- main ----------
def _predicted_c_si(hinge: dict, dx_m: float, dt_s: float) -> Tuple[float, str]:
    """Return ``(c_pred_si, source_label)`` for the expected front speed.

    If *hinge* has both ``T_star_plus1`` and ``R_eff_m``, the prediction is
    ``circ * R_eff_m / T_star_plus1``, where ``circ_factor`` is ``"2pi"``
    (default), ``"pi"`` (case-insensitive) or an explicit number. Otherwise
    the fallback is ``dx_m / dt_s``, i.e. one shell per tick (``main``
    enforces ``step_per_tick == 1``).

    Raises:
        SystemExit: for an unrecognised ``circ_factor`` or ``T_star_plus1 == 0``.
    """
    t_star = hinge.get("T_star_plus1", None)
    r_eff = hinge.get("R_eff_m", None)
    if t_star is None or r_eff is None:
        return dx_m / dt_s, "fallback(dx_m/dt_s)"

    circ = hinge.get("circ_factor", "2pi")
    if isinstance(circ, (int, float)) and not isinstance(circ, bool):
        factor = float(circ)
    else:
        key = str(circ).strip().lower()
        if key == "2pi":
            factor = 2.0 * math.pi
        elif key == "pi":
            factor = math.pi
        else:
            # Refuse to guess: a misread factor silently skews c_pred_si.
            raise SystemExit(
                f"Unrecognised hinge.circ_factor={circ!r}; expected '2pi', 'pi' or a number.")

    t_star_f = float(t_star)
    if t_star_f == 0.0:
        raise SystemExit("hinge.T_star_plus1 must be non-zero.")
    # Lsurf = factor * R_eff; c = Lsurf / T*
    return factor * float(r_eff) / t_star_f, "hinge(Lsurf/T*)"

def _native_hash(d_shells: List[int], t_ticks: List[int]) -> str:
    """Return the SHA-256 of the canonical (sorted-key) JSON of the native series."""
    blob = json.dumps({"d_shells": d_shells, "t_ticks": t_ticks},
                      sort_keys=True).encode("utf-8")
    return sha256_bytes(blob)

def main():
    """Run the J27 TOF simulation and write native + SI-overlay diagnostics.

    CLI arguments: ``--manifest`` (input JSON) and ``--outdir`` (output root).
    Under ``--outdir`` it writes a persisted copy of the manifest, an
    environment log, a per-detector arrivals CSV, an audit JSON and a
    one-line result summary. It also prints the result line.

    Raises:
        SystemExit: on an unsupported or invalid manifest setting.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--outdir", required=True)
    args = ap.parse_args()

    root = os.path.abspath(args.outdir)
    ensure_dirs(root, ["config","outputs/metrics","outputs/audits","outputs/run_info","logs"])

    # load manifest & persist a copy next to the outputs
    with open(args.manifest, "r", encoding="utf-8") as f:
        M = json.load(f)
    man_path = os.path.join(root, "config", "manifest_j27.json")
    json_dump(man_path, M)

    # env log
    write_text(os.path.join(root, "logs", "env.txt"),
               "\n".join([f"utc={utc_ts()}",
                          f"os={os.name}",
                          f"cwd={os.getcwd()}",
                          f"python={sys.version.split()[0]}"]))

    # geometry & control
    N  = int(M["grid"]["N"])
    cx = int(M["grid"].get("cx", N//2))
    cy = int(M["grid"].get("cy", N//2))
    H  = int(M["H"])
    r0 = int(M["source"]["radius_shells"])
    step = int(M.get("step_per_tick", 1))
    if step != 1:
        raise SystemExit("This overlay expects step_per_tick=1 (one shell per tick).")

    det_r = [int(r) for r in M["detectors"]["radii_shells"]]
    det_w = int(M["detectors"].get("width_shells", 1))
    if det_w < 1:
        # An annulus of width < 1 is empty, so no detector could ever fire.
        raise SystemExit(f"detectors.width_shells must be >= 1 (got {det_w}).")

    # scales (SI overlay); the section and each field are optional (default 1.0)
    scales = M.get("scales", {})
    dx_m = float(scales.get("dx_m_per_shell", 1.0))
    dt_s = float(scales.get("dt_s_per_tick", 1.0))
    for key, val in (("dx_m_per_shell", dx_m), ("dt_s_per_tick", dt_s)):
        # Zero would divide by zero below; negative/non-finite scales are meaningless.
        if not (math.isfinite(val) and val > 0):
            raise SystemExit(f"scales.{key} must be a finite positive number (got {val}).")

    # predicted c (hinge if present, else dx/dt fallback)
    c_pred_si, c_pred_src = _predicted_c_si(M.get("hinge") or {}, dx_m, dt_s)

    # simulate (present-act)
    shells = build_shell_map(N, cx, cy)
    arrivals = simulate_lightfront(H, r0, det_r, det_w, shells)

    # native series (distance in shells; time in ticks); unreached detectors are excluded
    native = [a for a in arrivals if a["arrival_tick"] is not None]
    missed = [a["detector_id"] for a in arrivals if a["arrival_tick"] is None]
    d_shells = [(a["radius_shells"] - r0) for a in native]
    t_ticks  = [a["arrival_tick"] for a in native]

    # Snapshot the native arrays BEFORE any fitting/overlay work, so the
    # post-overlay hash below actually tests that nothing mutated them.
    native_hash_before = _native_hash(d_shells, t_ticks)

    # fit native speed (shells per tick)
    c_hat_native, r2_native = linreg_slope(t_ticks, d_shells)

    # SI overlay: meters/seconds
    d_m = [d * dx_m for d in d_shells]
    t_s = [t * dt_s for t in t_ticks]
    c_hat_si, r2_si = linreg_slope(t_s, d_m)

    # overlay invariance: c_hat_si vs c_hat_native*dx/dt
    c_hat_from_native = c_hat_native * (dx_m / dt_s)
    rel_diff = abs(c_hat_si - c_hat_from_native) / max(1e-12, abs(c_hat_from_native))

    # re-serialise the native arrays after the overlay
    native_hash_after = _native_hash(d_shells, t_ticks)

    # acceptance thresholds (section and each field optional)
    acc = M.get("acceptance", {})
    rel_err_c_max = float(acc.get("rel_err_c_max", 0.02))
    r2_min        = float(acc.get("r2_min", 0.98))
    overlay_tol   = float(acc.get("overlay_rel_tol", 1e-9))

    # predicted speed comparison (in SI)
    c_hat_err = abs(c_hat_si - c_pred_si) / max(1e-12, abs(c_pred_si))

    # NaN fits (fewer than two arrivals) compare False and therefore fail.
    passed = bool(
        (r2_native >= r2_min) and (r2_si >= r2_min) and
        (c_hat_err <= rel_err_c_max) and
        (rel_diff <= overlay_tol) and
        (native_hash_before == native_hash_after)
    )

    # write metrics
    mpath = os.path.join(root, "outputs", "metrics", "j27_tof_arrivals.csv")
    with open(mpath, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["detector_id","radius_shells","arrival_tick",
                    "distance_shells","time_ticks","distance_m","time_s"])
        for a, ds, tt, dm, ts in zip(native, d_shells, t_ticks, d_m, t_s):
            w.writerow([a["detector_id"], a["radius_shells"], a["arrival_tick"],
                        f"{ds:.9f}", f"{tt:.9f}", f"{dm:.9f}", f"{ts:.9f}"])

    # audit JSON
    audit = {
        "sim": "J27_tof_si_overlay",
        "grid": M["grid"], "H": H, "source": M["source"], "detectors": M["detectors"],
        "scales": scales, "hinge_source": c_pred_src,
        "fit_native": {"c_hat_shells_per_tick": c_hat_native, "r2": r2_native},
        "fit_si": {"c_hat_m_per_s": c_hat_si, "r2": r2_si},
        "overlay": {
            "c_hat_from_native_m_per_s": c_hat_from_native,
            "overlay_rel_diff": rel_diff
        },
        "accept": {
            "rel_err_c_max": rel_err_c_max,
            "r2_min": r2_min,
            "overlay_rel_tol": overlay_tol
        },
        "pred": {"c_pred_si": c_pred_si},
        "detectors_missed": missed,
        "native_hash_before": native_hash_before,
        "native_hash_after": native_hash_after,
        "pass": passed
    }
    json_dump(os.path.join(root, "outputs", "audits", "j27_audit.json"), audit)

    # result line
    result = ("J27 PASS={p} c_hat_si={cs:.6f} c_pred_si={cp:.6f} "
              "r2_native={rn:.4f} r2_si={rs:.4f} overlay_rel_diff={od:.2e}"
              .format(p=passed, cs=c_hat_si, cp=c_pred_si,
                      rn=r2_native, rs=r2_si, od=rel_diff))
    write_text(os.path.join(root, "outputs", "run_info", "result_line.txt"), result)
    print(result)

# Entry point: run main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
